from abc import ABC, abstractmethod
from typing import Dict, NamedTuple, Optional, Union
import haiku as hk
import optax
import chex
import jax

from brax import envs
from brax.training import networks
from brax.training import normalization

from model_based_rl.replay_buffer import *
from model_based_rl.types import *
from model_based_rl.utils import nonlinearity


class DynamicsModel(ABC):
  """Dynamics model backed by the ground-truth Brax environment.

  ``step`` simulates transitions with the real environment created by
  ``envs.create`` rather than with a learned model. ``init``, ``loss``,
  ``update_step`` and ``train`` are no-op placeholders and ``apply`` is not
  implemented, so a learned-model subclass presumably overrides them.

  Model parameters are passed around as a ``(dynamics_params,
  normalizer_params)`` tuple; see ``step``.
  """

  def __init__(
    self,
    env_name: str,
    obs_size: int,
    acts_size: int,
    action_repeat: int,
    normalizer: normalization.Normalizer,
    termination_fn: Callable[[Observation, Action, NextObservation], jnp.ndarray],
    **kwargs,
  ):
    """Create the underlying Brax environment and check its sizes.

    Args:
      env_name: Brax environment name, forwarded to ``envs.create``.
      obs_size: Expected observation size; must equal the env's
        ``observation_size``.
      acts_size: Expected action size; must equal the env's ``action_size``.
      action_repeat: Forwarded to ``envs.create``.
      normalizer: Observation normalizer whose ``apply(params, obs)`` is
        used by ``step`` and ``rollout``.
      termination_fn: ``(obs, action, next_obs) -> done``, used by
        ``rollout`` to compute termination flags.
      **kwargs: Not used by this base class.

    Raises:
      ValueError: If ``acts_size``/``obs_size`` do not match the environment.
    """
    self.name = env_name
    self.act_size = acts_size
    self.obs_size = obs_size
    self.normalizer = normalizer
    self.action_repeat = action_repeat
    self.termination_fn = termination_fn
    # ``loss`` returns ``(value, metrics)``, hence ``has_aux=True``; the
    # gradient is taken w.r.t. its first argument (the dynamics params).
    self.grad_loss = jax.value_and_grad(self.loss, has_aux=True)

    self.core_env = envs.create(
      env_name=env_name,
      auto_reset=False,
      episode_length=None,
      batch_size=1,
      action_repeat=action_repeat,
    )

    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if self.act_size != self.core_env.action_size:
      raise ValueError(
        f"acts_size={self.act_size} does not match the action_size="
        f"{self.core_env.action_size} of environment '{env_name}'")
    if self.obs_size != self.core_env.observation_size:
      raise ValueError(
        f"obs_size={self.obs_size} does not match the observation_size="
        f"{self.core_env.observation_size} of environment '{env_name}'")
    self._step = jax.jit(self.core_env.step)

  def init(
      self,
      key: PRNGKey
  ) -> Tuple[Params, optax.OptState]:
    """Return empty ``(params, optimizer_state)``; nothing is learned here."""
    return {}, {}

  def loss(
      self,
      dynamics_params: Params,
      normalizer_params: Params,
      transitions: Transition,
  ) -> Tuple[jnp.ndarray, Metrics]:
    """Return a constant zero loss and no metrics (placeholder)."""
    return jnp.array(0.0), {}

  def update_step(
      self,
      training_state: TrainingState,
      transitions: Transition
  ) -> Tuple[TrainingState, Metrics]:
    """Return ``training_state`` unchanged and no metrics (placeholder)."""
    return training_state, {}

  def train(
      self,
      training_state: TrainingState,
      replay_buffer: ReplayBuffer,
  ) -> Tuple[TrainingState, Metrics]:
    """Return ``training_state`` unchanged and no metrics (placeholder)."""
    return training_state, {}

  def apply(
      self,
      params: Tuple[Params, Params],
      obs: Observation, acts: Action
  ) -> jnp.ndarray:
    """Evaluate the model on ``(obs, acts)``; not implemented in the base class.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError

  def step(
      self,
      params: Tuple[Params, Params],
      key: PRNGKey,
      state: brax.QP,
      observation: jnp.ndarray,
      norm_observation: Union[Observation, None],
      acts: Action,
  ) -> Transition:
    """Advance a batch of states by one step of the true Brax environment.

    Args:
      params: ``(dynamics_params, normalizer_params)``. Only the normalizer
        params are used here.
      key: Not used by this implementation.
      state: Batched Brax ``QP`` physics state matching ``observation``.
      observation: Observations of shape ``(batch, obs_size)``.
      norm_observation: Normalized ``observation``. If ``None`` it is
        computed with ``self.normalizer``.
      acts: Actions of shape ``(batch, act_size)``.

    Returns:
      A ``Transition`` from ``(state, observation)`` to the next state.
      ``d_t`` is all ones here; ``rollout`` overwrites it using
      ``termination_fn``.
    """
    chex.assert_shape(observation, (None, self.obs_size))
    chex.assert_shape(acts, (None, self.act_size))
    _, normalizer_params = params
    batch_size = observation.shape[0]

    if norm_observation is None:
      norm_observation = self.normalizer.apply(normalizer_params, observation)

    # Wrap the raw physics state in an env State so the jitted env ``step``
    # can consume it; reward/done start at zero.
    env_state = envs.State(
      qp=state,
      obs=observation,
      reward=jnp.zeros((batch_size,)),
      done=jnp.zeros((batch_size,)),
      metrics={},
      info={},
    )

    next_env_state = self._step(env_state, acts)
    next_observation = next_env_state.obs

    return Transition(
      s_tm1=state,
      o_tm1=observation,
      norm_o_tm1=norm_observation,
      a_tm1=acts,
      s_t=next_env_state.qp,
      o_t=next_observation,
      norm_o_t=self.normalizer.apply(normalizer_params, next_observation),
      r_t=next_env_state.reward,
      d_t=jnp.ones_like(next_env_state.done),
    )

  def rollout(
      self,
      key: PRNGKey,
      state: brax.QP,
      observation: Observation,
      norm_observation: Union[Observation, None],
      policy_params: Tuple[Params, Params],
      dynamics_params: Tuple[Params, Params],
      policy: Callable[[Tuple[Params, Params], PRNGKey, Observation],
                       Tuple[Action, jnp.ndarray]],
      num_steps: int = 1,
  ) -> Tuple[Transition, jnp.ndarray]:
    """Roll ``policy`` through the model for ``num_steps`` steps.

    Args:
      key: PRNG key; split once per step into a policy key and a dynamics key.
      state: Batched starting physics state.
      observation: Starting observations; must be rank 2 ``(batch, obs_size)``.
      norm_observation: Normalized starting observations, or ``None`` to
        compute them here.
      policy_params: Passed unchanged as the first argument to ``policy``.
      dynamics_params: ``(dynamics_params, normalizer_params)`` tuple,
        passed unchanged to ``step``.
      policy: ``(policy_params, key, obs) -> (action, log_prob_action)``.
      num_steps: Number of steps to unroll.

    Returns:
      ``(trajectories, log_prob_action)``. Each is stacked along a leading
      axis of length ``num_steps``, e.g.
      ``trajectories.o_t.shape == (num_steps, batch, obs_size)``.
      ``trajectories.d_t`` holds the not-done flag ``1 - termination_fn(...)``.
    """
    chex.assert_rank(observation, 2)

    if norm_observation is None:
      # Same tuple layout that ``step`` unpacks: normalizer params come SECOND.
      _, normalizer_params = dynamics_params
      norm_observation = self.normalizer.apply(normalizer_params, observation)

    def scan_body(carry, step_key):
      state_t, obs_t, norm_obs_t = carry
      key_policy, key_dynamics = jax.random.split(step_key)
      action, log_prob_action = policy(policy_params, key_policy, obs_t)
      transition = self.step(
          dynamics_params, key_dynamics, state_t, obs_t, norm_obs_t, action)
      next_carry = (transition.s_t, transition.o_t, transition.norm_o_t)
      return next_carry, (transition, log_prob_action)

    init_carry = (state, observation, norm_observation)
    step_keys = jax.random.split(key, num_steps)
    _, (trajectories, log_prob_action) = jax.lax.scan(
        scan_body, init_carry, step_keys)

    # Replace the placeholder not-done flag with one derived from the TRUE
    # termination function, evaluated on the stacked trajectory.
    done = self.termination_fn(
        trajectories.o_tm1, trajectories.a_tm1, trajectories.o_t)
    trajectories = trajectories.replace(d_t=1 - done)
    return trajectories, log_prob_action
